define a simple spline¶
In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
def bspline_basis_matrix():
    """Build the 4x4 basis matrix of a uniform cubic B-spline.

    The matrix is laid out for a row-vector power basis ``[u^3, u^2, u, 1]``
    on the left and four consecutive control points on the right, i.e.
    ``point(u) = [u^3, u^2, u, 1] @ M @ G``.

    Returns:
        A (4, 4) floating-point tensor holding the integer coefficient
        table scaled by 1/6.
    """
    # Integer coefficient table; the common 1/6 factor is applied below.
    coefficients = torch.tensor([
        [-1, 3, -3, 1],
        [3, -6, 3, 0],
        [-3, 0, 3, 0],
        [1, 4, 1, 0],
    ])
    scale = 1.0 / 6.0
    return scale * coefficients
class ClampedBSplineTrajectoryOptimizer:
    """Learnable clamped uniform cubic B-spline between a start and a goal.

    The start and goal points are each repeated three times in the control
    polygon so that the evaluated curve begins exactly at `start` and ends
    exactly at `goal`. Only the internal control points are trainable; they
    are registered with an Adam optimizer exposed as `self.optimizer`.
    """

    def __init__(self, start, goal, num_internal_ctrl_pts=6, dim=2, lr=0.05, device='cpu'):
        """
        Uses clamped B-spline with repeated endpoints to ensure path starts at `start` and ends at `goal`.

        Args:
            start, goal: Tensors of shape (dim,)
            num_internal_ctrl_pts: Number of learnable internal control points
            dim: Dimensionality of the path; stored on the instance only.
            lr: Learning rate handed to the Adam optimizer.
            device: Device on which all tensors of this object are placed.
        """
        self.device = device
        self.dim = dim
        self.start = start.to(device)
        self.goal = goal.to(device)

        # Internal control points (learnable): evenly spaced on the straight
        # line from start to goal, excluding the two endpoints themselves.
        internal = torch.linspace(0, 1, num_internal_ctrl_pts + 2, device=device).unsqueeze(1)
        internal = internal[1:-1]  # exclude endpoints
        internal_ctrl = internal * (self.goal - self.start) + self.start

        # Small Gaussian perturbation so the initial path is not a perfectly
        # straight line.
        eps = 0.05 * torch.randn_like(internal_ctrl)
        internal_ctrl += eps

        self.internal_ctrl_pts = nn.Parameter(internal_ctrl)
        self.lr = lr
        self.optimizer = optim.Adam([self.internal_ctrl_pts], lr=lr)
        self.basis = bspline_basis_matrix().to(device)

    def get_full_ctrl_pts(self):
        """Return the full control polygon, shape (num_internal + 6, dim).

        Layout: 3 copies of start, the learnable internal points, 3 copies
        of goal.
        """
        # Repeat start and goal 3 times for cubic clamping
        return torch.cat([
            self.start.expand(3, -1),
            self.internal_ctrl_pts,
            self.goal.expand(3, -1)
        ], dim=0)

    def evaluate_spline(self, resolution=100, stochastic=False):
        """
        Returns interpolated points along the clamped B-spline.
        If `stochastic` is True, samples random u values instead of fixed linspace.

        Args:
            resolution: Approximate total number of samples. Each segment
                receives ``resolution // num_segments`` samples, so the
                actual count is ``num_segments * (resolution // num_segments)``.
            stochastic: When True, the per-segment parameters are drawn
                uniformly from [0, 1) (unsorted) instead of an even linspace.

        Returns:
            Tensor of shape (total_samples, dim). With `stochastic=False`
            the first row equals `start` and the last row equals `goal`.
        """
        ctrl_pts = self.get_full_ctrl_pts()
        # A cubic segment needs 4 consecutive control points, so a polygon of
        # K points yields K - 3 segments.
        segments = ctrl_pts.shape[0] - 3
        samples_per_segment = resolution // segments

        points = []
        # Iterate over ALL segments. The final window is
        # [last_internal, goal, goal, goal]; only this window evaluates to
        # `goal` at u = 1, so skipping it would stop the curve short.
        for i in range(segments):
            G = ctrl_pts[i:i + 4]  # 4 control points
            if stochastic:
                u_vals = torch.rand(samples_per_segment, device=self.device)
            else:
                u_vals = torch.linspace(0, 1, samples_per_segment, device=self.device)
            U = torch.stack([u_vals**3, u_vals**2, u_vals, torch.ones_like(u_vals)], dim=1)  # (N, 4)
            segment_points = (U @ self.basis) @ G  # (N, dim)
            points.append(segment_points)

        return torch.cat(points, dim=0)  # shape (total_samples, dim)
def sdf_cost(sdf_values, alpha=35.0, beta=0.1):
    """Map signed-distance values to a non-negative collision cost.

    Two exponential penalties are evaluated element-wise and one of them is
    selected per element:

    * ``sdf < 0.25``  -> ``exp(-alpha * sdf)``  (steep penalty, grows quickly
      as the distance becomes negative)
    * otherwise       -> ``beta * exp(-sdf / beta)``  (small decaying tail)

    Args:
        sdf_values: Tensor of signed distances at each query point.
        alpha: Steepness of the penalty used below the 0.25 threshold.
        beta: Scale and decay length of the penalty used at or above the
            threshold.

    Returns:
        Tensor with the same shape as `sdf_values` holding the cost.
    """
    steep_penalty = torch.exp(-alpha * sdf_values)
    decaying_penalty = beta * torch.exp(-sdf_values / beta)
    below_threshold = sdf_values < 0.25
    return torch.where(below_threshold, steep_penalty, decaying_penalty)
def spacing_regularizer(ctrl_pts, strength=1.0):
    """Penalty on uneven spacing of consecutive control points.

    Computes the Euclidean length of every edge between neighbouring rows of
    `ctrl_pts` and returns `strength` times the mean squared deviation of
    those lengths from their mean (zero when all edges are equally long).

    Args:
        ctrl_pts: Tensor of shape (N, D), one control point per row.
        strength: Multiplicative weight of the penalty.

    Returns:
        Scalar tensor.
    """
    edge_lengths = torch.diff(ctrl_pts, dim=0).norm(dim=1)
    deviation = edge_lengths - edge_lengths.mean()
    return strength * deviation.pow(2).mean()
def repulsion_cost(ctrl_pts, min_dist=0.2, strength=1.0):
    """Log-barrier penalty for control points that crowd each other.

    Every unordered pair of rows in `ctrl_pts` is considered. Pairs whose
    Euclidean distance is below `min_dist` contribute
    ``-log(dist / min_dist + 1e-6)``; all other pairs contribute nothing.

    Args:
        ctrl_pts: Tensor of shape (N, D)
        min_dist: Distance below which a pair is penalized.
        strength: Multiplicative weight of the summed penalty.

    Returns:
        Scalar tensor (zero when no pair is closer than `min_dist`).
    """
    num_pts = ctrl_pts.shape[0]

    # Row/column indices of the strict upper triangle enumerate each
    # unordered pair exactly once.
    pair_index = torch.triu_indices(num_pts, num_pts, offset=1)
    first, second = pair_index[0], pair_index[1]

    pair_dists = (ctrl_pts[first] - ctrl_pts[second]).norm(dim=1)

    # Keep only the pairs that violate the minimum separation.
    violating = pair_dists[pair_dists < min_dist]
    barrier = -torch.log(violating / min_dist + 1e-6)
    return strength * barrier.sum()
In [30]:
import numpy as np
import plotly.graph_objects as go
def plot_sdf_grid_plotly_2d(env, grid_resolution=100, name="2D SDF Visualization"):
    """
    Plots the SDF-derived cost map of the environment using plotly in 2D.

    The environment's signed distance is sampled on a regular grid, passed
    through `sdf_cost`, and drawn as a heatmap-coloured contour trace.

    Args:
        env: An instance of EnvBase or its subclass. It must expose
            `limits_np` (indexed as [lower/upper][axis]), `tensor_args`
            (keyword arguments forwarded to `torch.linspace`) and
            `compute_sdf(points)`.
        grid_resolution: Number of points per axis.
        name: Title of the figure.

    Returns:
        The plotly `Figure`; it is not shown here so that callers can add
        further traces first.
    """
    # Define grid
    x = torch.linspace(env.limits_np[0][0], env.limits_np[1][0], grid_resolution, **env.tensor_args)
    y = torch.linspace(env.limits_np[0][1], env.limits_np[1][1], grid_resolution, **env.tensor_args)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    points = torch.stack([X, Y], dim=-1).view(-1, 1, 2)

    # Compute SDF and convert it to the cost that is actually displayed.
    sdf = env.compute_sdf(points).view(grid_resolution, grid_resolution)
    sdf = sdf_cost(sdf)
    sdf_np = sdf.cpu().numpy()

    # With indexing='ij' the array is laid out as sdf_np[ix, iy], whereas
    # plotly interprets z[row][col] as (y[row], x[col]). Transpose so the map
    # lines up with traces drawn in (x, y) coordinates.
    z_for_plotly = sdf_np.T

    # Plot with plotly
    fig = go.Figure(
        data=go.Contour(
            z=z_for_plotly,
            x=x.cpu().numpy(),
            y=y.cpu().numpy(),
            colorscale='Viridis',
            colorbar=None,
            line_width=0,
            contours=dict(
                coloring='heatmap',
                showlabels=False,  # show labels on contours
                labelfont=dict(  # label font properties
                    size=12,
                    color='white',
                )
            )
        )
    )
    fig.update_layout(
        xaxis_title='X',
        yaxis_title='Y',
        title=name,
        yaxis_scaleanchor="x",
        yaxis_scaleratio=1
    )
    return fig
Set up planning in an SDF environment via a spline¶
In [31]:
def plan_env(env2d, name="env", start=None, goal=None, num_iters=200):
    """Optimize a clamped B-spline path through a 2D environment and plot it.

    The cost map of `env2d` is drawn first, then a spline with 16 internal
    control points is optimized with Adam. The loss combines an obstacle
    cost evaluated at random curve samples, the polyline length of evenly
    sampled curve points, and a control-point spacing regularizer. The
    initial path, final path and final control points are overlaid on the
    figure, which is then shown.

    Args:
        env2d: 2D environment exposing `tensor_args['device']` and
            `compute_sdf(points)`, plus whatever `plot_sdf_grid_plotly_2d`
            requires.
        name: Figure title.
        start: Optional tensor of shape (2,); defaults to (-1, -1).
        goal: Optional tensor of shape (2,); defaults to (1, 1).
        num_iters: Number of optimizer steps; defaults to 200.
    """
    if start is None:
        start = torch.tensor([-1.0, -1.0])
    if goal is None:
        goal = torch.tensor([1.0, 1.0])

    # Visualize SDF grid in 2D
    fig = plot_sdf_grid_plotly_2d(env2d, grid_resolution=100, name=name)

    spline = ClampedBSplineTrajectoryOptimizer(
        start=start,
        goal=goal,
        num_internal_ctrl_pts=16,
        dim=2,
        lr=0.002,
        device=env2d.tensor_args['device'],
    )

    # plot initial path
    initial_np = spline.evaluate_spline().detach().cpu().numpy()
    fig.add_trace(go.Scatter(
        x=initial_np[:, 0], y=initial_np[:, 1],
        mode='lines', name='Initial Path', line=dict(color='blue'),
    ))

    with tqdm.trange(num_iters) as t:
        for _ in t:
            spline.optimizer.zero_grad()

            # Obstacle cost on randomly placed samples, so successive steps
            # probe different locations along the curve.
            spline_points = spline.evaluate_spline(resolution=100, stochastic=True)
            sdfs = env2d.compute_sdf(spline_points)
            obs_cost = sdf_cost(sdfs).sum()

            # Path length from evenly spaced samples.
            even_spline_points = spline.evaluate_spline(resolution=100, stochastic=False)
            len_cost = torch.norm(even_spline_points[1:] - even_spline_points[:-1], dim=1).sum()

            ctrl_pts = spline.get_full_ctrl_pts()
            spacing_cost = spacing_regularizer(ctrl_pts, strength=30)

            # All three terms are already scalars.
            cost = 5 * obs_cost + 2 * len_cost + spacing_cost
            cost.backward()
            spline.optimizer.step()

            t.set_postfix(obs_cost=obs_cost.item(), len_cost=len_cost.item(), spacing_cost=spacing_cost.item())

    # plot final path
    final_np = spline.evaluate_spline().detach().cpu().numpy()
    fig.add_trace(go.Scatter(
        x=final_np[:, 0], y=final_np[:, 1],
        mode='lines', name='Final Path', line=dict(color='red'),
    ))

    # Control points are drawn as markers, so the colour has to be set on
    # `marker`; a `line` colour has no effect in markers-only mode.
    ctrl_np = spline.get_full_ctrl_pts().detach().cpu().numpy()
    fig.add_trace(go.Scatter(
        x=ctrl_np[:, 0], y=ctrl_np[:, 1],
        mode='markers', name='Control Points',
        marker=dict(color='rgba(255, 125, 0, 0.5)'),
    ))
    fig.show()
In [ ]:
test each environment¶
In [32]:
from torch_robotics.torch_utils.torch_utils import DEFAULT_TENSOR_ARGS
# 2D environments
from torch_robotics.environments.env_circle_2d import EnvCircle2D
from torch_robotics.environments.env_dense_2d import EnvDense2D
from torch_robotics.environments.env_dense_2d_extra_objects import EnvDense2DExtraObjects
from torch_robotics.environments.env_grid_circles_2d import EnvGridCircles2D
from torch_robotics.environments.env_narrow_passage_dense_2d import EnvNarrowPassageDense2D
from torch_robotics.environments.env_narrow_passage_dense_2d_extra_objects import EnvNarrowPassageDense2DExtraObjects
from torch_robotics.environments.env_planar2link import EnvPlanar2Link
from torch_robotics.environments.env_simple_2d import EnvSimple2D
from torch_robotics.environments.env_simple_2d_extra_objects import EnvSimple2DExtraObjects
from torch_robotics.environments.env_square_2d import EnvSquare2D
# 3D environments
from torch_robotics.environments.env_maze_boxes_3d import EnvMazeBoxes3D
from torch_robotics.environments.env_spheres_3d import EnvSpheres3D
from torch_robotics.environments.env_spheres_3d_extra_objects import EnvSpheres3DExtraObjects
from torch_robotics.environments.env_table_shelf import EnvTableShelf
# Registry of environments to exercise, as (name, class, is3d) triples.
# The 2D entries come first, followed by the 3D ones.
envs = [
    (env_name, env_cls, False)
    for env_name, env_cls in (
        ("EnvCircle2D", EnvCircle2D),
        ("EnvDense2D", EnvDense2D),
        ("EnvDense2DExtraObjects", EnvDense2DExtraObjects),
        ("EnvGridCircles2D", EnvGridCircles2D),
        ("EnvNarrowPassageDense2D", EnvNarrowPassageDense2D),
        ("EnvNarrowPassageDense2DExtraObjects", EnvNarrowPassageDense2DExtraObjects),
        ("EnvPlanar2Link", EnvPlanar2Link),
        ("EnvSimple2D", EnvSimple2D),
        ("EnvSimple2DExtraObjects", EnvSimple2DExtraObjects),
        ("EnvSquare2D", EnvSquare2D),
    )
] + [
    (env_name, env_cls, True)
    for env_name, env_cls in (
        ("EnvMazeBoxes3D", EnvMazeBoxes3D),
        ("EnvSpheres3D", EnvSpheres3D),
        ("EnvSpheres3DExtraObjects", EnvSpheres3DExtraObjects),
        ("EnvTableShelf", EnvTableShelf),
    )
]

i = 0
for name, cls, is3d in envs:
    print(f"Plotting {name} ({'3D' if is3d else '2D'}) ...")
    try:
        # Every environment is constructed (which precomputes its SDF grid),
        # but only the 2D ones are planned in; 3D environments are built and
        # then left alone.
        env = cls(precompute_sdf_obj_fixed=True, sdf_cell_size=0.01, tensor_args=DEFAULT_TENSOR_ARGS)
        if not is3d:
            plan_env(env, name=name)
    except Exception as e:
        # Best-effort sweep: report the failure and continue with the next
        # environment.
        print(f"Failed to plot {name}: {e}")
Plotting EnvCircle2D (2D) ... Precomputing the SDF grid and gradients took: 0.044 sec
0%| | 0/200 [00:00<?, ?it/s, len_cost=2.95, obs_cost=4e+4, spacing_cost=0.235]
100%|██████████| 200/200 [00:03<00:00, 64.36it/s, len_cost=2.85, obs_cost=587, spacing_cost=0.248]
Plotting EnvDense2D (2D) ... Precomputing the SDF grid and gradients took: 0.036 sec
100%|██████████| 200/200 [00:02<00:00, 69.87it/s, len_cost=3.23, obs_cost=17.8, spacing_cost=0.519]
Plotting EnvDense2DExtraObjects (2D) ... Precomputing the SDF grid and gradients took: 0.033 sec
100%|██████████| 200/200 [00:03<00:00, 55.90it/s, len_cost=3.13, obs_cost=27.8, spacing_cost=0.577]
Plotting EnvGridCircles2D (2D) ... Precomputing the SDF grid and gradients took: 0.037 sec
100%|██████████| 200/200 [00:03<00:00, 60.96it/s, len_cost=3.76, obs_cost=25.7, spacing_cost=0.499]
Plotting EnvNarrowPassageDense2D (2D) ... Precomputing the SDF grid and gradients took: 0.033 sec
100%|██████████| 200/200 [00:03<00:00, 59.21it/s, len_cost=3.26, obs_cost=2.56, spacing_cost=0.486]
Plotting EnvNarrowPassageDense2DExtraObjects (2D) ... Precomputing the SDF grid and gradients took: 0.030 sec
100%|██████████| 200/200 [00:03<00:00, 51.55it/s, len_cost=3.15, obs_cost=5.62, spacing_cost=0.421]
Plotting EnvPlanar2Link (2D) ... Precomputing the SDF grid and gradients took: 0.017 sec
100%|██████████| 200/200 [00:03<00:00, 57.98it/s, len_cost=3.03, obs_cost=633, spacing_cost=0.474]
Plotting EnvSimple2D (2D) ... Precomputing the SDF grid and gradients took: 0.024 sec
100%|██████████| 200/200 [00:03<00:00, 55.31it/s, len_cost=2.99, obs_cost=2.44, spacing_cost=0.421]
Plotting EnvSimple2DExtraObjects (2D) ... Precomputing the SDF grid and gradients took: 0.020 sec
100%|██████████| 200/200 [00:03<00:00, 55.90it/s, len_cost=3.49, obs_cost=6.48, spacing_cost=0.602]
Plotting EnvSquare2D (2D) ... Precomputing the SDF grid and gradients took: 0.030 sec
100%|██████████| 200/200 [00:03<00:00, 59.37it/s, len_cost=2.92, obs_cost=4.19e+5, spacing_cost=0.349]
Plotting EnvMazeBoxes3D (3D) ... Precomputing the SDF grid and gradients took: 0.657 sec Plotting EnvSpheres3D (3D) ... Precomputing the SDF grid and gradients took: 0.117 sec Plotting EnvSpheres3DExtraObjects (3D) ... Precomputing the SDF grid and gradients took: 0.118 sec Plotting EnvTableShelf (3D) ... Precomputing the SDF grid and gradients took: 0.524 sec
In [ ]:
In [ ]:
In [ ]:
In [ ]: